# Copyright (C) 2025 Intel Corporation
#
# SPDX-License-Identifier: MIT

import os.path as osp
import re
from functools import cached_property
from typing import Dict, Generator, List, Optional, Set, Type, Union

from datumaro.components.annotation import (
    NO_OBJECT_ID,
    AnnotationType,
    Bbox,
    Caption,
    Cuboid2D,
    Cuboid3d,
    Ellipse,
    GroupType,
    Label,
    LabelCategories,
    MaskCategories,
    Points,
    PointsCategories,
    Polygon,
    PolyLine,
    RleMask,
    Skeleton,
)
from datumaro.components.dataset_base import DatasetItem, SubsetBase
from datumaro.components.errors import DatasetImportError, InvalidAnnotationError, MediaTypeError
from datumaro.components.importer import ImportContext
from datumaro.components.media import Image, MediaElement, MediaType, PointCloud, Video, VideoFrame
from datumaro.plugins.data_formats.datumaro.page_mapper import DatumPageMapper
from datumaro.util import parse_json_file, take_by
from datumaro.version import __version__

from .format import DATUMARO_FORMAT_VERSION, DatumaroPath

__all__ = ["DatumaroBase"]

LEGACY_VERSION = "legacy"


class JsonReader:
    def __init__(
        self,
        path: str,
        subset: str,
        rootpath: str,
        images_dir: str,
        pcd_dir: str,
        video_dir: str,
        ctx: ImportContext,
    ) -> None:
        """Open the annotation file at *path* and load its contents.

        Args:
            path: Location of the annotation file, handed to ``_init_reader``.
            subset: Subset name assigned to every parsed item.
            rootpath: Dataset root directory; an empty value disables the
                legacy related-images lookup.
            images_dir: Base directory used to resolve image paths.
            pcd_dir: Base directory used to resolve point cloud paths.
            video_dir: Base directory used to resolve video paths.
            ctx: Import context supplying the progress reporter and error policy.
        """
        # Per-reader state and collaborators.
        self._ctx = ctx
        self._subset = subset
        self._videos = {}
        self.ann_types = {}

        # Directories against which media paths are resolved.
        self._rootpath = rootpath
        self._images_dir = images_dir
        self._pcd_dir = pcd_dir
        self._video_dir = video_dir

        # The loaders run in a fixed order: _load_items() reads
        # self.media_type and self.dm_format_version, so both are set first.
        reader = self._init_reader(path)
        self._reader = reader
        self.media_type = self._load_media_type(reader)
        self.dm_format_version = self._load_dm_format_version(reader)
        self.infos = self._load_infos(reader)
        self.categories = self._load_categories(reader)
        self.items = self._load_items(reader)

    @cached_property
    def _legacy_related_images_dir(self) -> str:
        if self._rootpath:
            related_images_dir = osp.join(self._rootpath, DatumaroPath.LEGACY_RELATED_IMAGES_DIR)
            if osp.isdir(related_images_dir):
                return related_images_dir
        return ""

    def _init_reader(self, path: str):
        """Parse the annotation file at *path* and return the parsed content."""
        parsed = parse_json_file(path)
        return parsed

    @staticmethod
    def _load_media_type(parsed) -> Type[MediaElement]:
        """Return the media class declared by the ``"media_type"`` entry.

        ``MediaType.IMAGE`` is assumed when the entry is absent.
        """
        return MediaType(parsed.get("media_type", MediaType.IMAGE)).media

    @staticmethod
    def _load_dm_format_version(parsed) -> str:
        """Return the stored ``"dm_format_version"``, or ``LEGACY_VERSION`` if absent."""
        version = parsed.get("dm_format_version", LEGACY_VERSION)
        return version

    @staticmethod
    def _load_ann_types(parsed) -> Set[AnnotationType]:
        return parsed.get("ann_types", set())

    @staticmethod
    def _load_infos(parsed) -> Dict:
        return parsed.get("infos", {})

    @staticmethod
    def _load_categories(parsed) -> Dict:
        """Build category objects from the ``"categories"`` section of *parsed*.

        The label, mask and points sections are each optional; a section that
        is missing or empty produces no entry in the result.

        Returns:
            A dict mapping ``AnnotationType.label`` / ``.mask`` / ``.points``
            to the corresponding categories object.
        """
        sections = parsed["categories"]
        result = {}

        label_section = sections.get(AnnotationType.label.name)
        if label_section:
            label_cat = LabelCategories(attributes=label_section.get("attributes", []))
            for label_desc in label_section["labels"]:
                label_cat.add(
                    label_desc["name"],
                    parent=label_desc["parent"],
                    attributes=label_desc.get("attributes", []),
                )

            # Label groups are optional within the label section.
            for group_desc in label_section.get("label_groups", []):
                label_cat.add_label_group(
                    group_desc["name"],
                    labels=group_desc["labels"],
                    group_type=GroupType.from_str(group_desc["group_type"]),
                )

            result[AnnotationType.label] = label_cat

        mask_section = sections.get(AnnotationType.mask.name)
        if mask_section:
            result[AnnotationType.mask] = MaskCategories(
                colormap={
                    int(entry["label_id"]): (entry["r"], entry["g"], entry["b"])
                    for entry in mask_section["colormap"]
                }
            )

        points_section = sections.get(AnnotationType.points.name)
        if points_section:
            points_cat = PointsCategories()
            for entry in points_section["items"]:
                points_cat.add(
                    int(entry["label_id"]),
                    entry["labels"],
                    joints=entry["joints"],
                    positions=entry.get("positions"),
                )

            result[AnnotationType.points] = points_cat

        return result

    def _load_items(self, parsed) -> List[DatasetItem]:
        item_descs: List = parsed["items"]
        pbar = self._ctx.progress_reporter

        def _gen():
            # Reverse the list to pop from the front of it
            item_descs.reverse()
            while item_descs:
                yield item_descs.pop()

        items = []
        ann_types = set()
        actual_media_types = set()
        for item_desc in pbar.iter(
            _gen(), desc=f"Importing '{self._subset}'", total=len(item_descs)
        ):
            item = self._parse_item(item_desc)
            if item is not None:
                items.append(item)
                for ann in item.annotations:
                    ann_types.add(ann.type)
                if item.media:
                    actual_media_types.add(item.media.type)

        self.ann_types = ann_types

        if len(actual_media_types) == 1:
            actual_media_type = actual_media_types.pop()
        elif len(actual_media_types) > 1:
            actual_media_type = MediaType.MEDIA_ELEMENT
        else:
            actual_media_type = None

        if actual_media_type and self.dm_format_version == LEGACY_VERSION:
            self.media_type = actual_media_type.media

        if actual_media_type and not issubclass(actual_media_type.media, self.media_type):
            raise MediaTypeError(
                f"Unexpected media type of a dataset '{self.media_type}'. "
                f"Expected media type is '{actual_media_type.media}."
            )

        return items

    def _parse_item(self, item_desc: Dict) -> Optional[DatasetItem]:
        """Build a ``DatasetItem`` from one raw item description.

        The media object is created from the ``"image"``, ``"point_cloud"``,
        ``"video_frame"`` or ``"video"`` entry of *item_desc*. If a media
        object already exists when one of the later entries is found, a
        ``MediaTypeError`` is raised. The generic ``"media"`` entry is used
        only when none of the others produced a media object.

        Annotations are not parsed here; the item receives a callable that
        invokes ``_load_annotations`` on the raw description.

        Returns:
            The item, or ``None`` when an exception was raised while building
            the media; the exception is passed to the error policy instead.
        """
        STR_MULTIPLE_MEDIA = "DatasetItem cannot contain multiple media types"
        try:
            item_id = item_desc["id"]

            media = None
            image_info = item_desc.get("image")
            if image_info:
                # Without an explicit path, the file name is derived from the item id.
                image_filename = image_info.get("path") or item_id + DatumaroPath.IMAGE_EXT
                image_path = osp.join(self._images_dir, self._subset, image_filename)
                if not osp.isfile(image_path):
                    # backward compatibility
                    # (images stored directly under the images directory, without a subset folder)
                    old_image_path = osp.join(self._images_dir, image_filename)
                    if osp.isfile(old_image_path):
                        image_path = old_image_path

                media = Image.from_file(path=image_path, size=image_info.get("size"))

            pcd_info = item_desc.get("point_cloud")
            if media and pcd_info:
                raise MediaTypeError(STR_MULTIPLE_MEDIA)
            # A point cloud entry without a "path" is ignored.
            if pcd_info and (pcd_path := pcd_info.get("path")):
                point_cloud = osp.join(self._pcd_dir, self._subset, pcd_path)

                def make_related_image_path(ri_path):
                    # Legacy datasets keep related images in a per-item folder
                    # of the legacy related-images directory.
                    if self.dm_format_version == LEGACY_VERSION:
                        return osp.join(
                            self._legacy_related_images_dir, self._subset, item_id, ri_path
                        )
                    return osp.join(self._images_dir, self._subset, ri_path)

                related_images = None
                ri_info = item_desc.get("related_images")
                if ri_info:
                    related_images = [
                        Image.from_file(
                            size=ri.get("size"),
                            path=make_related_image_path(ri.get("path")),
                        )
                        for ri in ri_info
                    ]

                media = PointCloud.from_file(path=point_cloud, extra_images=related_images)

            video_frame_info = item_desc.get("video_frame")
            if media and video_frame_info:
                raise MediaTypeError(STR_MULTIPLE_MEDIA)
            if video_frame_info:
                video_path = osp.join(
                    self._video_dir, self._subset, video_frame_info.get("video_path")
                )
                # Frames of the same video share a single Video object.
                if video_path not in self._videos:
                    self._videos[video_path] = Video(video_path)
                video = self._videos[video_path]

                frame_index = video_frame_info.get("frame_index")

                media = VideoFrame(video, frame_index)

            video_info = item_desc.get("video")
            if media and video_info:
                raise MediaTypeError(STR_MULTIPLE_MEDIA)
            if video_info:
                video_path = osp.join(self._video_dir, self._subset, video_info.get("path"))
                # NOTE(review): the Video cached in self._videos here is not reused for
                # this item; a new Video carrying the frame range is built below.
                if video_path not in self._videos:
                    self._videos[video_path] = Video(video_path)
                step = video_info.get("step", 1)
                start_frame = video_info.get("start_frame", 0)
                end_frame = video_info.get("end_frame", None)
                media = Video(
                    path=video_path, step=step, start_frame=start_frame, end_frame=end_frame
                )

            # Generic fallback, used only if no specific media was created above.
            media_desc = item_desc.get("media")
            if not media and media_desc and media_desc.get("path"):
                media = MediaElement(path=media_desc.get("path"))

        except Exception as e:
            # Hand the failure to the error policy and signal "no item" to the caller.
            self._ctx.error_policy.report_item_error(
                e, item_id=(item_desc.get("id", None), self._subset)
            )
            return None

        return DatasetItem(
            id=item_id,
            subset=self._subset,
            # Deferred: annotations are parsed when this callable is invoked.
            annotations=lambda: self._load_annotations(item_desc),
            media=media,
            attributes=item_desc.get("attr"),
        )

    def _load_annotations(self, item: Dict):
        loaded = []

        for ann in item.get("annotations", []):
            try:
                ann_id = ann.get("id")
                ann_type = AnnotationType[ann["type"]]
                attributes = ann.get("attributes")
                group = ann.get("group")
                object_id = ann.get("object_id", NO_OBJECT_ID)

                label_id = ann.get("label_id")
                z_order = ann.get("z_order")
                points = ann.get("points")

                if ann_type == AnnotationType.label:
                    loaded.append(
                        Label(
                            label=label_id,
                            id=ann_id,
                            attributes=attributes,
                            group=group,
                            object_id=object_id,
                        )
                    )

                elif ann_type == AnnotationType.mask:
                    rle = ann["rle"]
                    rle["counts"] = rle["counts"].encode("ascii")
                    loaded.append(
                        RleMask(
                            rle=rle,
                            label=label_id,
                            id=ann_id,
                            attributes=attributes,
                            group=group,
                            object_id=object_id,
                            z_order=z_order,
                        )
                    )

                elif ann_type == AnnotationType.polyline:
                    loaded.append(
                        PolyLine(
                            points,
                            label=label_id,
                            id=ann_id,
                            attributes=attributes,
                            group=group,
                            object_id=object_id,
                            z_order=z_order,
                        )
                    )

                elif ann_type == AnnotationType.polygon:
                    loaded.append(
                        Polygon(
                            points,
                            label=label_id,
                            id=ann_id,
                            attributes=attributes,
                            group=group,
                            object_id=object_id,
                            z_order=z_order,
                        )
                    )

                elif ann_type == AnnotationType.bbox:
                    x, y, w, h = ann["bbox"]
                    loaded.append(
                        Bbox(
                            x,
                            y,
                            w,
                            h,
                            label=label_id,
                            id=ann_id,
                            attributes=attributes,
                            group=group,
                            object_id=object_id,
                            z_order=z_order,
                        )
                    )

                elif ann_type == AnnotationType.points:
                    loaded.append(
                        Points(
                            points,
                            label=label_id,
                            id=ann_id,
                            visibility=ann.get("visibility"),
                            attributes=attributes,
                            group=group,
                            object_id=object_id,
                            z_order=z_order,
                        )
                    )

                elif ann_type == AnnotationType.caption:
                    caption = ann.get("caption")
                    loaded.append(Caption(caption, id=ann_id, attributes=attributes, group=group))

                elif ann_type == AnnotationType.cuboid_3d:
                    loaded.append(
                        Cuboid3d(
                            ann.get("position"),
                            ann.get("rotation"),
                            ann.get("scale"),
                            label=label_id,
                            id=ann_id,
                            attributes=attributes,
                            group=group,
                            object_id=object_id,
                        )
                    )

                elif ann_type == AnnotationType.ellipse:
                    loaded.append(
                        Ellipse(
                            *points,
                            label=label_id,
                            id=ann_id,
                            attributes=attributes,
                            group=group,
                            object_id=object_id,
                            z_order=z_order,
                        )
                    )

                elif ann_type == AnnotationType.hash_key:
                    continue
                elif ann_type == AnnotationType.cuboid_2d:
                    loaded.append(
                        Cuboid2D(
                            list(map(tuple, points)),
                            label=label_id,
                            id=ann_id,
                            attributes=attributes,
                            group=group,
                            object_id=object_id,
                            z_order=z_order,
                        )
                    )
                elif ann_type == AnnotationType.skeleton:
                    loaded.append(
                        Skeleton(
                            elements=self._load_skeleton_elements_annotations(
                                ann, label_id, points
                            ),
                            label=label_id,
                            id=ann_id,
                            attributes=attributes,
                            group=group,
                            z_order=z_order,
                        )
                    )
                else:
                    raise NotImplementedError()
            except Exception as e:
                self._ctx.error_policy.report_annotation_error(
                    e, item_id=(ann.get("id", None), self._subset)
                )

        return loaded

    def _load_skeleton_elements_annotations(
        self, ann: dict, label_id: int, points: List[Union[float, int]]
    ) -> List[Points]:
        """Split the flat point list of a skeleton into one ``Points`` per element.

        Args:
            ann: Raw skeleton annotation; its ``"points_attributes"`` entry
                holds one attribute dict per element.
            label_id: Label index of the skeleton.
            points: Flat sequence of ``(x, y, visibility)`` triplets.

        Returns:
            One single-point ``Points`` annotation per triplet, labelled with
            the sub-label found under the skeleton's label.

        Raises:
            InvalidAnnotationError: If the number of values is not a multiple
                of three, or does not match the number of attribute dicts.
        """
        if len(points) % 3 != 0:
            raise InvalidAnnotationError(
                f"Points have invalid value count {len(points)}, "
                "which is not divisible by 3. Expected (x, y, visibility) triplets."
            )

        points_attributes = ann.get("points_attributes")
        if len(points_attributes) * 3 != len(points):
            raise InvalidAnnotationError(
                f"Points and Points_attributes lengths ({len(points)}, {len(points_attributes)}) "
                f"do not match, for each triplet (x, y, visibility) in points "
                f"there should be one dict in points_attributes."
            )

        label_category = self.categories[AnnotationType.label]
        sub_labels = self.categories[AnnotationType.points].items[label_id].labels

        elements = []
        triplets = take_by(points, 3)
        for (x, y, v), sub_label, attrs in zip(triplets, sub_labels, points_attributes):
            # Sub-labels are looked up as children of the skeleton's own label.
            parent_name = label_category.items[label_id].name
            element_label = label_category.find(name=sub_label, parent=parent_name)[0]
            elements.append(
                Points(
                    points=[x, y],
                    visibility=[v],
                    label=element_label,
                    attributes=attrs,
                )
            )
        return elements

    def __len__(self):
        return len(self.items)

    def __iter__(self) -> Generator[DatasetItem, None, None]:
        yield from self.items


class StreamJsonReader(JsonReader):
    def __init__(
        self,
        path: str,
        subset: str,
        rootpath: str,
        images_dir: str,
        pcd_dir: str,
        video_dir: str,
        ctx: ImportContext,
    ) -> None:
        """Open *path* for streaming; items are parsed on iteration, not here.

        Takes the same arguments as ``JsonReader.__init__``. For legacy-format
        files the media type is taken from the first item that carries media;
        if no item does, the media type loaded by the base class is kept.
        """
        super().__init__(path, subset, rootpath, images_dir, pcd_dir, video_dir, ctx)
        self._length = None
        if self.dm_format_version == LEGACY_VERSION:
            for item in self:
                # Iteration may produce None for an item that failed to parse;
                # such entries must be skipped instead of dereferenced.
                if item is not None and item.media:
                    self.media_type = item.media.type.media
                    break

    def __len__(self):
        """Return the length reported by the underlying page mapper."""
        page_mapper = self._reader
        return len(page_mapper)

    def __iter__(self) -> Generator[DatasetItem, None, None]:
        """Parse and yield items one at a time from the page mapper.

        Items that fail to parse are skipped: ``_parse_item`` has already
        passed the failure to the error policy and returned ``None``, and the
        non-streaming reader drops such items as well.

        When the iteration runs to completion, ``self.ann_types`` is replaced
        by the annotation types of the items whose annotations had been
        initialized by the time iteration moved past them.
        """
        ann_types = set()
        pbar = self._ctx.progress_reporter
        for item_desc in pbar.iter(
            self._reader,
            desc=f"Importing '{self._subset}'",
        ):
            item = self._parse_item(item_desc)
            if item is None:
                # Never hand None to consumers expecting DatasetItem objects.
                continue

            yield item

            # Only inspect annotations that are already initialized, so that
            # collecting the types does not trigger lazy annotation loading.
            if item.annotations_are_initialized:
                for ann in item.annotations:
                    ann_types.add(ann.type)
        self.ann_types = ann_types

    def _init_reader(self, path: str) -> DatumPageMapper:
        """Create the page mapper used to read the annotation file at *path*."""
        page_mapper = DatumPageMapper(path)
        return page_mapper

    @staticmethod
    def _load_media_type(page_mapper: DatumPageMapper) -> Type[MediaElement]:
        """Return the media class declared by the page mapper.

        Falls back to the image media class when no media type is declared.
        """
        declared = page_mapper.media_type
        return MediaType.IMAGE.media if declared is None else declared.media

    @staticmethod
    def _load_ann_types(page_mapper: DatumPageMapper) -> Set[AnnotationType]:
        """Return the annotation types declared by the page mapper, or an empty set."""
        declared = page_mapper.ann_types
        return set() if declared is None else declared

    @staticmethod
    def _load_infos(page_mapper: DatumPageMapper) -> Dict:
        """Return the ``infos`` value exposed by the page mapper."""
        infos = page_mapper.infos
        return infos

    @staticmethod
    def _load_categories(page_mapper: DatumPageMapper) -> Dict:
        """Build category objects from the categories exposed by the page mapper.

        The categories are wrapped in the dict layout expected by
        ``JsonReader._load_categories``, which does the actual work.
        """
        wrapped = {"categories": page_mapper.categories}
        return JsonReader._load_categories(wrapped)

    def _load_items(self, parsed) -> List:
        """Return an empty list: the streaming reader parses items on iteration."""
        nothing_preloaded: List = []
        return nothing_preloaded

    @staticmethod
    def _load_dm_format_version(page_mapper: DatumPageMapper) -> str:
        """Return the format version declared by the page mapper.

        ``LEGACY_VERSION`` is returned when the declared value is falsy.
        """
        declared = page_mapper.dm_format_version
        return declared if declared else LEGACY_VERSION


class DatumaroBase(SubsetBase):
    CURRENT_DATUMARO_FORMAT_VERSION = DATUMARO_FORMAT_VERSION

    # If Datumaro format version goes up, it will be
    # ALLOWED_VERSIONS = {LEGACY_VERSION, 1.0, ..., CURRENT_DATUMARO_FORMAT_VERSION}
    ALLOWED_VERSIONS = {LEGACY_VERSION, CURRENT_DATUMARO_FORMAT_VERSION}

    def __init__(
        self,
        path: str,
        *,
        subset: Optional[str] = None,
        stream: bool = False,
        ctx: Optional[ImportContext] = None,
    ):
        """Load one Datumaro-format annotation file as a dataset subset.

        Args:
            path: Path of an existing annotation file.
            subset: Subset name; defaults to the file name without extension.
            stream: Read items lazily with ``StreamJsonReader`` instead of
                loading them all with ``JsonReader``.
            ctx: Import context forwarded to the base class and the reader.

        Raises:
            DatasetImportError: If the format version found in the file is not
                one of ``ALLOWED_VERSIONS``.
        """
        assert osp.isfile(path), path

        if subset is None:
            # Derive the subset name from the annotation file name.
            subset, _ = osp.splitext(osp.basename(path))

        self._stream = stream
        super().__init__(subset=subset, ctx=ctx)
        self._init_path(path)

        dm_version = self._get_dm_format_version(path)

        # Once older format versions need dedicated handling,
        # version-specific readers should be implemented here.
        version_is_supported = dm_version in self.ALLOWED_VERSIONS
        if not version_is_supported:
            raise DatasetImportError(
                f"Datumaro format version of the given dataset is {dm_version}, "
                f"but not supported by this Datumaro version: {__version__}. "
                f"The allowed datumaro format versions are {self.ALLOWED_VERSIONS}. "
                "Please install the latest Datumaro."
            )

        self._load_impl(path)

    def _init_path(self, path: str):
        """Derive the dataset root and media directories from *path*.

        The root is the part of *path* preceding the annotations directory,
        provided the file sits directly inside that directory; otherwise it is
        empty. Each media directory is set only if it exists under the root,
        and is an empty string otherwise.
        """
        rootpath = ""
        annotations_suffix = osp.join(DatumaroPath.ANNOTATIONS_DIR, osp.basename(path))
        if path.endswith(annotations_suffix):
            rootpath = path.rsplit(DatumaroPath.ANNOTATIONS_DIR, maxsplit=1)[0]
        self._rootpath = rootpath

        def _existing_subdir(dirname: str) -> str:
            # Path of `dirname` under the root, or "" if there is no root
            # or no such directory.
            if not rootpath:
                return ""
            candidate = osp.join(rootpath, dirname)
            return candidate if osp.isdir(candidate) else ""

        self._images_dir = _existing_subdir(DatumaroPath.IMAGES_DIR)
        self._pcd_dir = _existing_subdir(DatumaroPath.PCD_DIR)
        self._video_dir = _existing_subdir(DatumaroPath.VIDEO_DIR)

    @property
    def is_stream(self) -> bool:
        """Whether this instance was created with ``stream=True``."""
        streaming = self._stream
        return streaming

    def infos(self):
        """Return the ``infos`` value held by the underlying reader."""
        reader = self._reader
        return reader.infos

    def categories(self):
        """Return the categories loaded by the underlying reader."""
        reader = self._reader
        return reader.categories

    def media_type(self):
        """Return the media type determined by the underlying reader."""
        reader = self._reader
        return reader.media_type

    def ann_types(self):
        """Return the annotation types currently recorded by the underlying reader."""
        reader = self._reader
        return reader.ann_types

    def _load_impl(self, path: str) -> JsonReader:
        """Create the reader for *path* and store it in ``self._reader``.

        A ``StreamJsonReader`` is used when the instance was created with
        ``stream=True``, a ``JsonReader`` otherwise; both receive the same
        arguments.

        Returns:
            The newly created reader (the same object as ``self._reader``).
        """
        reader_cls = StreamJsonReader if self._stream else JsonReader
        self._reader = reader_cls(
            path,
            self._subset,
            self._rootpath,
            self._images_dir,
            self._pcd_dir,
            self._video_dir,
            self._ctx,
        )
        return self._reader

    def _get_dm_format_version(self, path) -> str:
        """
        Get the Datumaro format version recorded when the dataset was exported.

        Legacy Datumaro does not store a version in the exported dataset; in
        that case ``LEGACY_VERSION`` is returned.

        Args:
            path: Path of the annotation file to inspect.

        Returns:
            The version string (e.g. ``"1.0"``), or ``LEGACY_VERSION`` when no
            ``"dm_format_version"`` entry is found in the inspected prefix.

        Raises:
            DatasetImportError: If an entry is found but its value does not
                start with ``<digits>.<digits>``.
        """
        # We can assume that the version information will be within the first 1 KB of the file.
        # This is because we are the producer of Datumaro format.
        search_size = 1024  # 1 KB

        # Capture everything between the quotes: a version such as "1.0"
        # contains a dot, which a word-character-only group cannot match.
        pattern = r'"dm_format_version"\s*:\s*"([^"]+)"'

        with open(path, "r", encoding="utf-8") as fp:
            header = fp.read(search_size)

        found = re.search(pattern, header)
        if not found:
            return LEGACY_VERSION

        version_str = found.group(1)

        # Prefix match on purpose: an unknown but well-formed version is left
        # to the caller's allowed-version check, which gives a clearer error.
        if re.match(r"\d+\.\d+", version_str) is None:
            raise DatasetImportError(f"Invalid version string: {version_str} ")

        return version_str

    def __len__(self) -> int:
        """Return the number of items reported by the underlying reader."""
        reader = self._reader
        return len(reader)

    def __iter__(self) -> Generator[DatasetItem, None, None]:
        """Yield the dataset items produced by the underlying reader."""
        yield from self._reader
